"""
Test security fixes for various DoS vulnerabilities.
"""

import asyncio
from unittest.mock import patch

import pytest

from backend.blocks.code_extraction_block import CodeExtractionBlock
from backend.blocks.iteration import StepThroughItemsBlock
from backend.blocks.llm import AITextSummarizerBlock
from backend.blocks.text import ExtractTextInformationBlock
from backend.blocks.xml_parser import XMLParserBlock
from backend.util.file import store_media_file
from backend.util.type import MediaFileType


class TestCodeExtractionBlockSecurity:
    """Regression tests for the ReDoS hardening of CodeExtractionBlock."""

    async def test_redos_protection(self):
        """Run the block on a fence opener followed by a long run of spaces.

        The input is an opening python code fence with 10,000 trailing spaces
        and no closing fence. The test passes when the run finishes and a
        ``remaining_text`` output is among the emitted values.
        """
        extractor = CodeExtractionBlock()

        # Input shape that would previously cause ReDoS in the regex patterns.
        hostile_text = "```python" + " " * 10000

        outputs = [
            (name, data)
            async for name, data in extractor.run(
                CodeExtractionBlock.Input(text=hostile_text)
            )
        ]

        # Reaching this point means the run completed instead of hanging.
        assert len(outputs) >= 1
        emitted_names = [name for name, _ in outputs]
        assert "remaining_text" in emitted_names


class TestAITextSummarizerBlockSecurity:
    """Checks on the memory-exhaustion guards of AITextSummarizerBlock."""

    def test_split_text_limits(self):
        """Exercise ``_split_text`` with oversized, fragmented and odd input."""
        split = AITextSummarizerBlock._split_text

        # 1) Text size limit: 2MB of input should be truncated to 1MB.
        oversized = "a" * 2_000_000
        pieces = split(oversized, 1000, 100)
        char_total = 0
        for piece in pieces:
            char_total += len(piece)
        # Slack of 1000 characters allows for the chunk boundary.
        assert char_total <= 1_000_000 + 1000

        # 2) Chunk count limit: many small chunks must stay within MAX_CHUNKS.
        pieces = split("word " * 10000, 10, 9)
        assert len(pieces) <= 100

        # 3) Parameter validation: overlap > max_tokens should still work.
        pieces = split("test", 10, 15)
        assert len(pieces) >= 1


class TestExtractTextInformationBlockSecurity:
    """Test ReDoS and memory exhaustion fixes in ExtractTextInformationBlock."""

    async def test_text_size_limits(self):
        """Run a greedy pattern over 2MB of text and check the match cap.

        The block is run with ``find_all=True`` on two million ``"a"``
        characters. If a ``matched_results`` output is emitted, it must hold
        at most 1000 entries (the MAX_MATCHES limit).
        """
        block = ExtractTextInformationBlock()

        # Test with large input
        large_text = "a" * 2_000_000  # 2MB

        results = [
            (output_name, output_data)
            async for output_name, output_data in block.run(
                ExtractTextInformationBlock.Input(
                    text=large_text, pattern=r"a+", find_all=True, group=0
                )
            )
        ]

        # Should complete and have limits applied. The size check is only made
        # when the block actually emitted a "matched_results" output.
        matched_results = [r for name, r in results if name == "matched_results"]
        if matched_results:
            assert len(matched_results[0]) <= 1000  # MAX_MATCHES limit

    async def test_dangerous_pattern_timeout(self):
        """Check that a lookahead pattern over 1000 characters finishes quickly.

        Only the elapsed time is asserted; the emitted outputs are consumed
        and discarded.
        """
        block = ExtractTextInformationBlock()

        # Test with potentially dangerous lookahead pattern
        test_input = "a" * 1000

        # get_running_loop() is the supported way to reach the loop from
        # inside a coroutine; its clock is monotonic.
        loop = asyncio.get_running_loop()
        start_time = loop.time()
        async for _ in block.run(
            ExtractTextInformationBlock.Input(
                text=test_input, pattern=r"(?=.+)", find_all=True, group=0
            )
        ):
            pass
        elapsed = loop.time() - start_time

        # Deliberately generous upper bound: the run is expected to finish
        # far sooner than this.
        assert elapsed < 10

    async def test_redos_catastrophic_backtracking(self):
        """Test that ReDoS patterns with catastrophic backtracking are handled.

        Uses the pattern ``(a+)+b`` against thirty ``"a"`` characters with no
        trailing ``"b"``. The run must finish in under 6 seconds and the first
        ``matched_results`` output must be an empty list.
        """
        block = ExtractTextInformationBlock()

        # Pattern that causes catastrophic backtracking: (a+)+b
        # With input of only 'a's (no 'b'), this causes exponential time
        dangerous_pattern = r"(a+)+b"
        test_input = "a" * 30  # 30 'a's without a 'b' at the end

        # This should be handled by timeout protection or pattern detection
        loop = asyncio.get_running_loop()
        start_time = loop.time()

        results = [
            (output_name, output_data)
            async for output_name, output_data in block.run(
                ExtractTextInformationBlock.Input(
                    text=test_input, pattern=dangerous_pattern, find_all=True, group=0
                )
            )
        ]

        elapsed = loop.time() - start_time

        # Should complete within timeout (6 seconds to be safe).
        # NOTE(review): an earlier comment here stated that the
        # threading.Timer-based timeout did not work and that this test was
        # expected to fail; confirm whether that still applies.
        assert elapsed < 6, f"Regex took {elapsed}s, timeout mechanism failed"

        # Should return empty results on timeout or no match
        matched_results = [r for name, r in results if name == "matched_results"]
        # Fail with a readable message rather than an IndexError below.
        assert matched_results, "block emitted no 'matched_results' output"
        assert matched_results[0] == []  # No matches expected


class TestStepThroughItemsBlockSecurity:
    """Verify the iteration limits enforced by StepThroughItemsBlock."""

    async def test_item_count_limits(self):
        """A list with too many items must be rejected with ``ValueError``."""
        stepper = StepThroughItemsBlock()

        # 20,000 entries, which exceeds MAX_ITEMS (10000).
        too_many = [*range(20000)]

        # The input is built inside the context so that an error raised
        # while constructing it is also seen by pytest.raises.
        with pytest.raises(ValueError, match="Too many items"):
            async for _unused in stepper.run(
                StepThroughItemsBlock.Input(items=too_many)
            ):
                continue

    async def test_string_size_limits(self):
        """An oversized ``items_str`` must be rejected with ``ValueError``."""
        stepper = StepThroughItemsBlock()

        # 200,000 repetitions of an 8-character fragment: a 1.6M-char string.
        oversized_text = '["item"]' * 200000

        with pytest.raises(ValueError, match="Input too large"):
            async for _unused in stepper.run(
                StepThroughItemsBlock.Input(items_str=oversized_text)
            ):
                continue

    async def test_normal_iteration_works(self):
        """A small list is still iterated in order."""
        stepper = StepThroughItemsBlock()

        emitted = [
            (name, data)
            async for name, data in stepper.run(
                StepThroughItemsBlock.Input(items=[1, 2, 3])
            )
        ]

        # Two outputs (item and key) for each of the 3 items.
        assert len(emitted) == 6

        yielded_items = []
        for name, data in emitted:
            if name == "item":
                yielded_items.append(data)
        assert yielded_items == [1, 2, 3]


class TestXMLParserBlockSecurity:
    """Verify that XMLParserBlock rejects oversized XML input."""

    async def test_xml_size_limits(self):
        """XML larger than the 10MB limit must raise ``ValueError``."""
        parser = XMLParserBlock()

        # One "<item>data</item>" element is 17 characters, so 620,000 copies
        # come to about 10.5 million characters, which exceeds the 10MB limit.
        element = "<item>data</item>"
        oversized_xml = f"<root>{element * 620000}</root>"

        with pytest.raises(ValueError, match="XML too large"):
            async for _unused in parser.run(
                XMLParserBlock.Input(input_xml=oversized_xml)
            ):
                continue


class TestStoreMediaFileSecurity:
    """Checks on the storage limits enforced by store_media_file."""

    @staticmethod
    def _stub_storage_and_scan(mock_cloud_storage, mock_scan):
        """Configure the two patched collaborators shared by every test here.

        ``mock_cloud_storage`` stands in for ``get_cloud_storage_handler`` and
        ``mock_scan`` for ``scan_content_safe``.
        """
        from unittest.mock import MagicMock

        # is_cloud_path is a sync method; make it report a non-cloud path.
        local_handler = MagicMock()
        local_handler.is_cloud_path.return_value = False

        # get_cloud_storage_handler is async, so its stand-in is a coroutine
        # function that hands back the mock handler.
        async def fake_get_handler():
            return local_handler

        mock_cloud_storage.side_effect = fake_get_handler
        mock_scan.return_value = None

    @patch("backend.util.file.scan_content_safe")
    @patch("backend.util.file.get_cloud_storage_handler")
    async def test_file_size_limits(self, mock_cloud_storage, mock_scan):
        """A data URI with a 200MB payload must be rejected as too large."""
        self._stub_storage_and_scan(mock_cloud_storage, mock_scan)

        # 200MB of base64 content.
        payload = "a" * (200 * 1024 * 1024)
        oversized_uri = "data:text/plain;base64," + payload

        with pytest.raises(ValueError, match="File too large"):
            await store_media_file(
                graph_exec_id="test",
                file=MediaFileType(oversized_uri),
                user_id="test_user",
            )

    @patch("backend.util.file.Path")
    @patch("backend.util.file.scan_content_safe")
    @patch("backend.util.file.get_cloud_storage_handler")
    async def test_directory_size_limits(self, mock_cloud_storage, mock_scan, MockPath):
        """A small upload must fail when the directory already holds 2GB."""
        from unittest.mock import MagicMock

        self._stub_storage_and_scan(mock_cloud_storage, mock_scan)

        # A single existing file reporting a size of 2GB (over the 1GB total).
        oversized_file = MagicMock()
        oversized_file.is_file.return_value = True
        oversized_file.stat.return_value.st_size = 2 * 1024 * 1024 * 1024

        # Every Path(...) call in backend.util.file yields this mock, which
        # exists and whose glob() lists the oversized file.
        exec_dir = MagicMock()
        exec_dir.exists.return_value = True
        exec_dir.glob.return_value = [oversized_file]
        MockPath.return_value = exec_dir

        # The upload itself is tiny; only the directory size should trip.
        with pytest.raises(ValueError, match="Disk usage limit exceeded"):
            await store_media_file(
                graph_exec_id="test",
                file=MediaFileType("data:text/plain;base64,dGVzdA=="),
                user_id="test_user",
            )
